{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate class-slice indexing table for experiments\n",
    "\n",
    "\n",
    "### Overview\n",
    "\n",
    "This notebook sets up experiments that simulate few-shot image segmentation scenarios.\n",
    "\n",
    "Input: pre-processed images and their ground-truth labels\n",
    "\n",
    "Output: a `json` file for class-slice indexing: for each class name and each scan ID, the list of slice indices that contain that class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n",
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# Start from a clean namespace; -f skips the interactive y/n confirmation so the cell can run unattended.\n",
    "%reset -f\n",
    "# Re-import edited modules automatically before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import numpy as np\n",
    "import os\n",
    "import glob\n",
    "import SimpleITK as sitk\n",
    "import sys\n",
    "import json\n",
    "sys.path.insert(0, '../../dataloaders/')\n",
    "import niftiio as nio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "IMG_BNAME=\"./chaos_MR_T2_normalized/image_*.nii.gz\"\n",
    "SEG_BNAME=\"./chaos_MR_T2_normalized/label_*.nii.gz\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _scan_number(path):\n",
    "    \"\"\"Sort key: the integer found between the last '_' and '.nii.gz' in `path`.\"\"\"\n",
    "    return int(path.split(\"_\")[-1].split(\".nii.gz\")[0])\n",
    "\n",
    "\n",
    "# Sort numerically (not lexically) so that e.g. image_10 comes after image_2.\n",
    "# imgs[i] and segs[i] refer to the same scan only if both patterns match the same set of ids.\n",
    "imgs = sorted(glob.glob(IMG_BNAME), key=_scan_number)\n",
    "segs = sorted(glob.glob(SEG_BNAME), key=_scan_number)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['./chaos_MR_T2_normalized/image_1.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_2.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_3.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_5.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_8.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_10.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_13.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_15.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_19.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_20.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_21.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_22.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_31.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_32.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_33.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_34.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_36.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_37.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_38.nii.gz',\n",
       " './chaos_MR_T2_normalized/image_39.nii.gz']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "imgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['./chaos_MR_T2_normalized/label_1.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_2.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_3.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_5.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_8.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_10.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_13.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_15.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_19.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_20.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_21.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_22.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_31.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_32.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_33.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_34.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_36.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_37.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_38.nii.gz',\n",
       " './chaos_MR_T2_normalized/label_39.nii.gz']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "segs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pid 1 finished!\n",
      "pid 2 finished!\n",
      "pid 3 finished!\n",
      "pid 5 finished!\n",
      "pid 8 finished!\n",
      "pid 10 finished!\n",
      "pid 13 finished!\n",
      "pid 15 finished!\n",
      "pid 19 finished!\n",
      "pid 20 finished!\n",
      "pid 21 finished!\n",
      "pid 22 finished!\n",
      "pid 31 finished!\n",
      "pid 32 finished!\n",
      "pid 33 finished!\n",
      "pid 34 finished!\n",
      "pid 36 finished!\n",
      "pid 37 finished!\n",
      "pid 38 finished!\n",
      "pid 39 finished!\n"
     ]
    }
   ],
   "source": [
    "# The loop below treats list index i as the integer label value i stored in the label volumes.\n",
    "LABEL_NAME = [\"BG\", \"LIVER\", \"RK\", \"LK\", \"SPLEEN\"]\n",
    "\n",
    "# Minimum number of pixels of a class that a slice must contain to be recorded for that class.\n",
    "# Use >100 when training with manual annotations for more stable training.\n",
    "MIN_TP = 1\n",
    "\n",
    "fid = f'./chaos_MR_T2_normalized/classmap_{MIN_TP}.json'  # name of the output file\n",
    "\n",
    "# Scan id = the text between the last \"_\" and \".nii.gz\" of each label path.\n",
    "pids = [_sid.split(\"_\")[-1].split(\".nii.gz\")[0] for _sid in segs]\n",
    "\n",
    "# classmap[class name][scan id] -> indices (along axis 0 of the label volume) of the slices holding that class.\n",
    "classmap = {_lb: {pid: [] for pid in pids} for _lb in LABEL_NAME}\n",
    "\n",
    "for pid, seg in zip(pids, segs):\n",
    "    lb_vol = nio.read_nii_bysitk(seg)\n",
    "    for slc in range(lb_vol.shape[0]):\n",
    "        for cls, cls_name in enumerate(LABEL_NAME):\n",
    "            # Count the pixels of this class only. Summing the raw label values of the whole slice\n",
    "            # would let other classes (and larger label ids) satisfy the threshold on its behalf.\n",
    "            n_px = np.count_nonzero(lb_vol[slc, ...] == cls)\n",
    "            # n_px > 0 keeps the \"class must be present\" requirement even when MIN_TP <= 0.\n",
    "            if n_px > 0 and n_px >= MIN_TP:\n",
    "                classmap[cls_name][pid].append(slc)\n",
    "    print(f'pid {pid} finished!')\n",
    "\n",
    "# The with-block closes the file on exit, so no explicit close() is needed.\n",
    "with open(fid, 'w') as fopen:\n",
    "    json.dump(classmap, fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the class-slice table to `fid`; the with-block closes the file on exit,\n",
    "# so no explicit close() is needed.\n",
    "with open(fid, 'w') as fopen:\n",
    "    json.dump(classmap, fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
